-
-
Notifications
You must be signed in to change notification settings - Fork 74
Feat: Handle Adjoints through Initialization #1168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
I wanted to ask whether it is preferred to retain |
LinearSolve.jl should be faster across the board? It depends a bit on the CPU architecture since it depends on whether it guesses the right LU correctly, |
Note that with the latest MTK update, there is now an |
src/concrete_solve.jl
Outdated
@@ -425,6 +425,21 @@ function DiffEqBase._concrete_solve_adjoint( | |||
save_end = true, kwargs_fwd...) | |||
end | |||
|
|||
# Get gradients for the initialization problem if it exists | |||
igs = if _prob.f.initialization_data.initializeprob != nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be before the solve, since you can use the initialization solution from here in the remake
s of 397-405 in order to set new u0
and p
and thus skip running the initialization a second time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How can I indicate to solve
to avoid running initialization?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
initializealg = NoInit()
. Should probably just do CheckInit()
for safety but either is fine.
src/steadystate_adjoint.jl
Outdated
@@ -103,15 +102,18 @@ end | |||
else | |||
if linsolve === nothing && isempty(sensealg.linsolve_kwargs) | |||
# For the default case use `\` to avoid any form of unnecessary cache allocation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I don't know about that comment. I think it's just old. (a) \
always allocates because it uses lu
instead of lu!
, so it's re-allocating the while matrix which is larger than any LinearSolve allocation, and (b) we have since 2023 setup tests on StaticArrays, so the immutable path is non-allocating. I don't think (b) was true when this was written.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So glad we can remove this branch altogether.
src/concrete_solve.jl
Outdated
iprob = _prob.f.initialization_data.initializeprob | ||
ip = parameter_values(iprob) | ||
itunables, irepack, ialiases = canonicalize(Tunable(), ip) | ||
igs, = Zygote.gradient(ip) do ip |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gradient isn't used? I think this would go into the backpass and if I'm thinking clearly, the resulting return is dp .* igs
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet. These gradients are currently against the parameters of the initialization problem, not the system exactly. And the mapping between the two is ill defined, so we cannot simply accum
I spoke with @AayushSabharwal about a way to map, it seems initialization_data.intializeprobmap
might have some support to return the correctly shaped vector, but there are cases where we cannot know the ordering of dp
either.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling accumulation of gradients. Or any transforms we might need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we would need to worry about incorrect gradients manually
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yes, you need to use the initializeprobmap https://github.com/SciML/SciMLBase.jl/blob/master/src/initialization.jl#L268 to map it back to the shape of the initial parameters.
but there are cases where we cannot know the ordering of dp either.
p and dp just need the same ordering, so initializeprobmap should do the trick.
There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling accumulation of gradients. Or any transforms we might need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we would need to worry about incorrect gradients manually
This is the only change to (u0,p)
before solving, so this would account for it, given initializeprobmap
is just an index map so an identity function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed this occurance in 95ebbf3 to check if this is correct. Will need to work around the global
call
Trying to use the initialization end to end caused gradients against parameters to get dropped. https://github.com/DhairyaLGandhi/SciMLBase.jl/tree/dg/nonlinear is a WIP branch which adds adjoints to the |
src/concrete_solve.jl
Outdated
new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, SciMLBase.OverrideInit(), Val(true); | ||
abstol = 1e-6, | ||
reltol = 1e-6, | ||
sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't default to ZygoteVJP. Should use the autojacvec of the ODE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed in 9a8a845
I could use some understanding of how to handle initialization when MTK analytically solves the problem, and removes all the unknowns. In that case |
@@ -402,8 +414,8 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t; | |||
# because hasportion(Tunable(), NullParameters) == false | |||
__p = p isa SciMLBase.NullParameters ? _p : | |||
SciMLStructures.replace(Tunable(), p, _p) | |||
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t | |||
vec(f(u, p, first(t))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same down here?
src/concrete_solve.jl
Outdated
abstol = 1e-6, | ||
reltol = 1e-6, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These don't make sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, these should probably inherit from kwargs
or be set up to some default. Note that we must specify a tol for this dispatch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed in 984c2ce
src/concrete_solve.jl
Outdated
@@ -412,16 +415,40 @@ function DiffEqBase._concrete_solve_adjoint( | |||
Base.diff_names(Base._nt_names(values(kwargs)), | |||
(:callback_adj, :callback))}(values(kwargs)) | |||
isq = sensealg isa QuadratureAdjoint | |||
|
|||
igs, new_u0, new_p = if _prob.f.initialization_data !== nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also needs to check that initializealg is not set, is the default, or is using OverrideInit. Should test this is not triggered with say manual BrownBasicInit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. My understanding was that OverrideInit was what we strictly needed. We can check for BrownBasicInit/ defaults here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There doesn't seem to be a method which can take a BrownFullBasicInit()
. I get a MethodError
:
ERROR: MethodError: no method matching get_initial_values(::ODEProblem{…}, ::ODEProblem{…}, ::ODEFunction{…}, ::BrownFullBasicInit{…}, ::Val{…}; sensealg::SteadyStateAdjoint{…}, nlsolve_alg::Nothing)
Closest candidates are:
get_initial_values(::Any, ::Any, ::Any, ::NoInit, ::Any; kwargs...)
@ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:282
get_initial_values(::Any, ::Any, ::Any, ::SciMLBase.OverrideInit, ::Union{Val{true}, Val{false}}; nlsolve_alg, abstol, reltol, kwargs...)
@ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:224
get_initial_values(::SciMLBase.AbstractDEProblem, ::SciMLBase.DEIntegrator, ::Any, ::CheckInit, ::Union{Val{true}, Val{false}}; abstol, kwargs...)
@ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:161
...
Only CheckInit
, NoInit
, and OverrideInit
have dispatches.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having chatted with @AayushSabharwal on this, it seems like BrownBasic and ShampineCollocation do not yet have a path through get_initial_values
and that would need to be fixed in OrdianryDiffEq. Further, as SciMLSensitivity does not depend on OrdinaryDiffEq, it cannot check for whether there is a default initialisation.since those are defined there. Depending on it also seems like a pretty big hammer for a dep.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What would be the best course of action here? Seems like supporting BrownBasicInit is a dispatch that will automatically be utilised when it is moved into SciMLBase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I mean BrownBasicInit should not be taking this path. But that's a problem because then they will be disabled in the next stage below, and that needs to be accounted for. This dispatch is already built and setup for BrownBasicInit and there are tests on that.
src/concrete_solve.jl
Outdated
if sensealg isa BacksolveAdjoint | ||
sol = solve(_prob, alg, args...; save_noise = true, | ||
sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should only noinit if the previous case was ran. Won't this right now break the brownbasic tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there was no intialization data, it won't have ran the initialization problem at all.
If I can genetically ignore handling initializealg
and pass it directly to get_initial_values
, that would be good. Then I can also pass NoInit
here genetically.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No that is not correct. If there was no initialization data then it will use the built in initialization, defaulting to BrownBasicInit. It's impossible for a DAE solver to generally work without running initialization of some form, the MTK one is just a new specialized one but there has always been a numerical one in the solver. And if it hits that case, this code will now disable that.
https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L952-L978 this code will hit that. I think it's not failing because it's not so pronounced here. You might want to change that test to https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L975C5-L975C69 prob_singular_mm = ODEProblem(f, [1.0, 0.0, 1.0], (0.0, 100), p)
and it would pass before and fail now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right of course for the DAEs, but since BrownBasicInit
is defined in OrdinaryDiffEq, and this package does not depend on it, I need a way for us to be able to dispatch to it. So if I understand the comment from earlier, we need a check for the default initialization, and add a branch that solves for that prob
, and collect all the outputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both BrownBasicInit
and OrdinaryDiffEqCore.DefaultInit
require us to depend on a whole package for the default dispatch. Can it be exposed as a dispatch of get_initial_values
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh 😅 . That case was too simple, MTK turns it into an ODE. Let's make it a DAE.
@parameters σ ρ β A[1:3]
@variables x(t) y(t) z(t) w(t) w2(t)
eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z,
w ~ x + y + z + 2 * β
0 ~ x^2 + y^2 - w2^2
]
@mtkbuild sys = ODESystem(eqs, t)
That should make it so that it eliminates the w
term, but doesn't eliminate the w2
term. The DAE check is on the w2 term, the observed handling check is on the w term.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That will need to change the integrator to Rodas5P
, Tsit5 will not be compatible with this form.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For SDEs, we will just need to make it compatible with BrownBasicInit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I was so confused why it worked out, I see the InitialFailure
now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay good. Yeah because MTK is too smart and makes lots of simple examples not DAEs 😅. But now you got the DAE, and if not running the built in init then you get the error I was expecting. The fix is that it needs to run brownbasic before solving for the same reason reverse needs to. Good we worked out a test for this
src/concrete_solve.jl
Outdated
elseif isnothing(_out) | ||
_out | ||
else | ||
@. _out[_save_idxs] = Δ.u[_save_idxs] | ||
end | ||
end | ||
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When the new reverse ode is built it needs to drop the initial eqs but still keep the dae constraints. It can brownbasic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to drop the initial eqs after its solved? The assumption was since we run with NoInit
, no initialization is run post the first call to get_initial_values
and we accumulate those gradients independently of the adaptive solve.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But the reverse pass needs to run with some form of initialization or the starting algebraic conditions may not be satisfied. Don't run this one with NoInit(), that would be prone to hiding issue. For this one, at most CheckInit(), but I'm saying that BrownBasicInit() is likely the one justified here since the 0 initial condition is only true on the differential variables, while the algebraic variable initial conditions will be unknown, but the Newton solve will have zero derivative because all of the inputs are just Newton guesses, so BrownBasic will work out for the reverse. We should probably hardcode that since it's always the solution there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, that will require us to add an OrdinaryDiffEqCore dep in this package. I will add that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the 0 derivative also applicable to parameters? Or only the unknowns?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its applicable to all Newton guess values. There is no parameter init going on to reverse so it's only for algebraic conditions so it's only Newton guesses.
… to ODE form; check grads
Co-authored-by: Christopher Rackauckas <[email protected]>
test/mtk.jl
Outdated
dmtk_incorrectu0, = Zygote.gradient(mtkparams_incorrectu0) do p | ||
new_sol = solve(prob_incorrectu0, Rodas5P(); p = p, initializealg = BrownFullBasicInit(), sensealg, abstol = 1e-6, reltol = 1e-3) | ||
Zygote.ChainRules.ChainRulesCore.ignore_derivatives() do | ||
@test new_sol.retcode == SciMLBase.ReturnCode.Success | ||
@test all(isapprox.(new_sol[x + y + z + 2 * β - w], 0, atol = 1e-12)) | ||
@test all(isapprox.(new_sol[x^2 + y^2 - w2^2], 0, atol = 1e-5, rtol = 1e0)) | ||
end | ||
mean(abs.(new_sol[sys.x] .- gt)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dmtk_incorrectu0, = Zygote.gradient(mtkparams_incorrectu0) do p | |
new_sol = solve(prob_incorrectu0, Rodas5P(); p = p, initializealg = BrownFullBasicInit(), sensealg, abstol = 1e-6, reltol = 1e-3) | |
Zygote.ChainRules.ChainRulesCore.ignore_derivatives() do | |
@test new_sol.retcode == SciMLBase.ReturnCode.Success | |
@test all(isapprox.(new_sol[x + y + z + 2 * β - w], 0, atol = 1e-12)) | |
@test all(isapprox.(new_sol[x^2 + y^2 - w2^2], 0, atol = 1e-5, rtol = 1e0)) | |
end | |
mean(abs.(new_sol[sys.x] .- gt)) | |
end | |
dmtk_overrideinit_incorrectu0, = Zygote.gradient(mtkparams_incorrectu0) do p | |
new_sol = solve(prob_incorrectu0, Rodas5P(); p = p, initializealg = OverrideInit(), sensealg, abstol = 1e-6, reltol = 1e-3) | |
Zygote.ChainRules.ChainRulesCore.ignore_derivatives() do | |
@test new_sol.retcode == SciMLBase.ReturnCode.Success | |
@test all(isapprox.(new_sol[x + y + z + 2 * β - w], 0, atol = 1e-12)) | |
@test all(isapprox.(new_sol[x^2 + y^2 - w2^2], 0, atol = 1e-5, rtol = 1e0)) | |
end | |
mean(abs.(new_sol[sys.x] .- gt)) | |
end | |
dmtk_incorrectu0, = Zygote.gradient(mtkparams_incorrectu0) do p | |
new_sol = solve(prob_incorrectu0, Rodas5P(); p = p, initializealg = BrownFullBasicInit(), sensealg, abstol = 1e-6, reltol = 1e-3) | |
Zygote.ChainRules.ChainRulesCore.ignore_derivatives() do | |
@test new_sol.retcode == SciMLBase.ReturnCode.Success | |
@test all(isapprox.(new_sol[x + y + z + 2 * β - w], 0, atol = 1e-12)) | |
@test all(isapprox.(new_sol[x^2 + y^2 - w2^2], 0, atol = 1e-5, rtol = 1e0)) | |
end | |
mean(abs.(new_sol[sys.x] .- gt)) | |
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should test OverrideInit and DefaultInit
I added the BrownBasic for the reverse and setup the tests so that there's a version that would capture that issue of the MTK init in the reverse pass. It almost certainly needs the fix to SciML/ModelingToolkit.jl#3570 to pass though, so it will likely fail at first and @AayushSabharwal this is a reason to prioritize getting that one completed. But when that is merged and this passes, then I think this is good to go. |
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
MTK and SciML construct an initialization problem before starting the time stepping to ensure the starting values of the unknowns and parameters adhere to any constraints needed for the system. This PR adds handling for adjoint sensitivities of the NonlinearProblem, NonlinearSquaresProblem, SCCNonlinearProblem etc.
I am opening this to get some feedback regarding how we can accumulate gradients correctly. I have also included a test case for a DAE which I will update to use the values out of SciMLSensitivity.
Add any other context about the problem here.
Currently the gradients get calculated but don't get accumulated, we need to be able to update the gradients for the parameters. Since this is a manual dispatch, the usual graph building in AD is bypassed, and we need to handle this manually. Ideally, we should make it so the cfg itself includes the initialization so we would not have gotten incorrect gradients in the first place 😅 We are also forced to use a LinearProblem instead of
\
because it cannot handle singular jacobians.cc @ChrisRackauckas